{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "28b2116b-d8db-4ecf-931b-6faf03950b26",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "device(type='cuda')"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import matplotlib.pyplot as plt\n",
    "import torchvision.transforms as transforms\n",
    "from PIL import Image\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "device"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cb810376-7bd1-4ce1-8e70-f67001d1f27b",
   "metadata": {},
   "source": [
    "### 加载模型"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "236d9a7f-674f-42c5-9a81-587ddd6aac18",
   "metadata": {},
   "source": [
    "加载生成器模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "5dc72b2e-ec95-4b68-926f-2e02e2652be8",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\PC\\AppData\\Local\\Temp\\ipykernel_26836\\4019217968.py:69: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  generator.load_state_dict(torch.load(model_path))\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class Generator(nn.Module):\n",
    "    def __init__(self, input_dim=100, output_dim=1, class_num=10):\n",
    "        '''\n",
    "        初始化生成网络\n",
    "        :param input_dim:输入随机噪声的维度，（随机噪声是为了增加输出多样性）\n",
    "        :param output_dim:生成图像的通道数（灰度图为1，RGB图为3）\n",
    "        :param class_num:图像种类\n",
    "        '''\n",
    "        super(Generator, self).__init__()\n",
    "        \"\"\"\n",
    "         为什么需要拼接随机噪声和条件向量？\n",
    "         拼接随机噪声和条件向量的目的是将两种信息结合起来，作为生成器的输入：\n",
    "         随机噪声：提供生成数据的随机性。\n",
    "         条件向量：提供生成数据的条件信息。\n",
    "         通过拼接，生成器可以根据条件向量生成符合特定条件的数据, 同时确保每次生成的数据会有所不同\n",
    "         \"\"\"\n",
    "        self.input_dim = input_dim\n",
    "        self.class_num = class_num\n",
    "        self.output_dim = output_dim\n",
    "        \n",
    "        # 嵌入层处理条件向量(类别标签), 提高条件信息的表达能力\n",
    "        self.label_emb = nn.Embedding(class_num, class_num)\n",
    "        \n",
    "        # 全连接层，将输入向量映射到高维空间，然后通过反卷积层生成图像\n",
    "        self.fc = nn.Sequential(\n",
    "            nn.Linear(self.input_dim + self.class_num, 256),\n",
    "            nn.LeakyReLU(0.2, inplace=True),\n",
    "            nn.Linear(256, 512),\n",
    "            nn.LeakyReLU(0.2, inplace=True),\n",
    "            nn.Linear(512, 1024),\n",
    "            nn.LeakyReLU(0.2, inplace=True),\n",
    "            nn.Linear(1024, 128 * 7 * 7),\n",
    "            nn.BatchNorm1d(128 * 7 * 7),\n",
    "            nn.LeakyReLU(0.2, inplace=True)\n",
    "        )\n",
    "\n",
    "        # 反卷积层（转置卷积层），用于将高维特征图逐步上采样为最终图像\n",
    "        self.deconv = nn.Sequential(\n",
    "            nn.ConvTranspose2d(128, 128, 4, 2, 1),  # 7x7 -> 14x14\n",
    "            nn.BatchNorm2d(128),\n",
    "            nn.LeakyReLU(0.2, inplace=True),\n",
    "            nn.ConvTranspose2d(128, 64, 4, 2, 1),   # 14x14 -> 28x28\n",
    "            nn.BatchNorm2d(64),\n",
    "            nn.LeakyReLU(0.2, inplace=True),\n",
    "            nn.Conv2d(64, self.output_dim, 3, 1, 1),  # 保持尺寸不变，但细化特征\n",
    "            nn.Tanh()  # 激活函数，将输出值限制在 [-1, 1] 范围内，适合生成图像\n",
    "        )\n",
    " \n",
    "    def forward(self, noise, labels):\n",
    "        # 标签处理\n",
    "        label_embedding = self.label_emb(labels)\n",
    "        \n",
    "        # 拼接噪声和条件向量\n",
    "        x = torch.cat([noise, label_embedding], dim=1)\n",
    "        \n",
    "        # 通过全连接层\n",
    "        x = self.fc(x)\n",
    "        \n",
    "        # 重塑为特征图\n",
    "        x = x.view(-1, 128, 7, 7)\n",
    "        \n",
    "        # 通过反卷积层生成图像\n",
    "        x = self.deconv(x)\n",
    "        \n",
    "        return x\n",
    "\n",
    "generator = Generator().to(device)\n",
    "model_path = '../models/4_GAN_Image_Generator/MINIST_generator.pth'\n",
    "generator.load_state_dict(torch.load(model_path))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d02b634b-2977-4d6a-966b-39e63dcae20f",
   "metadata": {},
   "source": [
    "加载数字识别模型，用于辅助数字图像生成"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "54ae2abe-8062-4bee-8d14-28360709e7ec",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "C:\\Users\\PC\\AppData\\Local\\Temp\\ipykernel_26836\\3895384144.py:31: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  digit_recognizer.load_state_dict(torch.load(model_path))\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "class Digit_recognizer(torch.nn.Module):\n",
    "    def __init__(self):\n",
    "        # （batch,1,28,28）\n",
    "        super(Digit_recognizer, self).__init__()\n",
    "        self.conv1 = torch.nn.Sequential(\n",
    "            torch.nn.Conv2d(1, 32, kernel_size=3), #（batch,32,26,26）\n",
    "            torch.nn.ReLU(),\n",
    "            torch.nn.MaxPool2d(kernel_size=2), #（batch,32,13,13）\n",
    "        )\n",
    "        self.conv2 = torch.nn.Sequential(\n",
    "            torch.nn.Conv2d(32, 64, kernel_size=3), #（batch,64,11,11）\n",
    "            torch.nn.ReLU(),\n",
    "            torch.nn.MaxPool2d(kernel_size=2), #（batch,64,5,5）\n",
    "        )\n",
    "        self.fc = torch.nn.Sequential(\n",
    "            torch.nn.Linear(1600, 50),\n",
    "            # torch.nn.ReLU(),  # 添加ReLU激活函数\n",
    "            torch.nn.Linear(50, 10),\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        batch_size = x.size(0)\n",
    "        x = self.conv1(x)  # 一层卷积层,一层池化层,一层激活层\n",
    "        x = self.conv2(x)  # 再来一次\n",
    "        x = x.view(batch_size, -1)  # flatten 变成全连接网络需要的输入\n",
    "        x = self.fc(x)\n",
    "        return x  # 最后输出的是维度为10的，也就是（对应数学符号的0~9）\n",
    "\n",
    "digit_recognizer = Digit_recognizer().to(device)\n",
    "model_path = '../models/1_Handwritten_Digit_Recognition/model_weights.pth'\n",
    "digit_recognizer.load_state_dict(torch.load(model_path))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "992d8fda-ac50-441a-95b4-e4500b25368f",
   "metadata": {},
   "source": [
    "### 手写数字图像生成"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "810ace5f-ffa4-4dd6-9ff9-b88db3571ea8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/jpeg": "/9j/4AAQSkZJRgABAQAAAQABAAD/2wBDAAgGBgcGBQgHBwcJCQgKDBQNDAsLDBkSEw8UHRofHh0aHBwgJC4nICIsIxwcKDcpLDAxNDQ0Hyc5PTgyPC4zNDL/wAALCAAcABwBAREA/8QAHwAAAQUBAQEBAQEAAAAAAAAAAAECAwQFBgcICQoL/8QAtRAAAgEDAwIEAwUFBAQAAAF9AQIDAAQRBRIhMUEGE1FhByJxFDKBkaEII0KxwRVS0fAkM2JyggkKFhcYGRolJicoKSo0NTY3ODk6Q0RFRkdISUpTVFVWV1hZWmNkZWZnaGlqc3R1dnd4eXqDhIWGh4iJipKTlJWWl5iZmqKjpKWmp6ipqrKztLW2t7i5usLDxMXGx8jJytLT1NXW19jZ2uHi4+Tl5ufo6erx8vP09fb3+Pn6/9oACAEBAAA/APn+uy8FfDPX/HlvdXGlfZooLZgjS3TMisx52qQpyQME+mR61peJvgv4p8L6Rc6pcvp9zaWyhpWtpySozjOGVfb8687rp/AXhCTxp4lSwM8dvaQJ9pvZnfbsgVlDkHBG75uM8evFb/i34mzv/wASTwYX0Xw3a/JCltmOSfHV3b73P1ye+TU/wp1aTU9X1fw/q+pTCw1XTbhJJJ5CyROF3eY244GAG5OPrXmVd38KdY0bTPEl/Z69O9tp+r6bNpr3C8CLzCvJPYYUjODgkZ4zWxJ8JtAEqCL4neGmi3HezTorBexA3nJ9sj6mqnivWfDvhvwunhLwjcx30k536rq6xbWuCCcRqT0T/dOMdzls+b0UUUV//9k=",
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAABSElEQVR4AWNgZMAJ8Ejh1EOMBLK5jIwsvxlZmX/8w6aRZ+nr249//3y9YJ0rJ4a88te//4Dg799/f6LBkkwIJSxXORj////z/dXf/4yiYGEWhGQv6//vH6/c2c3uECIgyfgfIcHAwP7o759bKfysjIyM3Jlfr3Ehy3F9/Pe3jgXqdvVX99SQJBlv//tXAfcW7+PvsUiSyr9+zUDirvv7FM5jbv/7dw2cB2Sk/fvLBuOzfP73BcYG00J/f4BosD85OBlKUSRNvqnC+Rx//iH5l4Fh5vZmuBwD179/cJcyMLDd/vuRGSHJ8u+fLoL34u8HQQSPgeHfvw8wrYwL/v6shnHAas78+/dpzeproV6dZ77++TKHHVkjA8s7UFxBwN/zMiiuY2CQ2fn5y8fvP3//erPaHjOimYEeZmRkYkSKXxSzKeeg+AO3cQDHXn1ioQLH8QAAAABJRU5ErkJggg==",
      "text/plain": [
       "<PIL.Image.Image image mode=L size=28x28>"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def generate_digit_image(generator, digit):\n",
    "    \"\"\"\n",
    "    生成指定数字的图片\n",
    "    :param generator: 训练好的生成器模型\n",
    "    :param digit: 要生成的数字 (0-9)\n",
    "    :return: 生成的图片 (PIL 图像)\n",
    "    \"\"\"\n",
    "    generator.eval()  # 设置为评估模式\n",
    "    with torch.no_grad():\n",
    "        # 生成随机噪声\n",
    "        noise = torch.randn(1, generator.input_dim).to(device)\n",
    "        \n",
    "        # 创建标签\n",
    "        label = torch.tensor([digit]).to(device)\n",
    "        \n",
    "        # 生成图片\n",
    "        fake_image = generator(noise, label)\n",
    "        \n",
    "        # 将图片从 [-1, 1] 转换到 [0, 1]\n",
    "        fake_image = (fake_image.squeeze().cpu() + 1) / 2.0\n",
    "        \n",
    "        # 将 2D 张量 (H, W) 转换为 3D 张量 (1, H, W)\n",
    "        fake_image = fake_image.unsqueeze(0)\n",
    "        \n",
    "        # 转换为 PIL 图像\n",
    "        fake_image = transforms.ToPILImage()(fake_image)\n",
    "        \n",
    "        return fake_image\n",
    "\n",
    "# 预测数字\n",
    "def predict_image(digit_recognizer, image):\n",
    "    # 图像预处理\n",
    "    transform = transforms.Compose([\n",
    "        transforms.Grayscale(num_output_channels=1),  # 转换为灰度\n",
    "        transforms.Resize((28, 28)),                 # 调整到 28x28\n",
    "        transforms.ToTensor(),                       # 转换为张量\n",
    "        transforms.Normalize((0.5,), (0.5,))         # 归一化到 [-1, 1]\n",
    "    ])\n",
    "    image = transform(image)\n",
    "    image = image.to(device)\n",
    "    image = image.unsqueeze(0)\n",
    "    \n",
    "    with torch.no_grad():\n",
    "        output = digit_recognizer(image)\n",
    "        _, predicted = torch.max(output.data, 1)\n",
    "    return predicted.item()\n",
    "\n",
    "# 利用数字识别模型确保生成的数字图片更清晰\n",
    "def generate_high_quality_digit_image(generator, digit_recognizer, digit):\n",
    "    while True:\n",
    "        digit_image = generate_digit_image(generator, digit)\n",
    "        predict_label = predict_image(digit_recognizer, digit_image)\n",
    "        if int(predict_label) == digit:\n",
    "            return digit_image\n",
    "\n",
    "generate_high_quality_digit_image(generator, digit_recognizer, 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "1f5562b0-1932-4d67-b264-1bb4755c8b4c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "图片已保存到 ./data/demo.png\n"
     ]
    }
   ],
   "source": [
    "# 生成一个包含 10x10 个不同数字的大图片，并保存到本地\n",
    "plt.figure(figsize=(10, 10))  # 设置画布大小\n",
    "plt.subplots_adjust(wspace=0.1, hspace=0.1)  # 调整子图间距\n",
    "\n",
    "for i in range(10):  # 行\n",
    "    for j in range(10):  # 列\n",
    "        # 生成数字 j 的图片\n",
    "        digit_image = generate_high_quality_digit_image(generator, digit_recognizer, j)\n",
    "        \n",
    "        # 将图片添加到子图中\n",
    "        ax = plt.subplot(10, 10, i * 10 + j + 1)\n",
    "        ax.imshow(digit_image, cmap='gray')\n",
    "        ax.axis('off')  # 关闭坐标轴\n",
    "\n",
    "save_path = './data/demo.png'\n",
    "# 保存大图片\n",
    "plt.savefig(save_path, bbox_inches='tight')\n",
    "plt.close()\n",
    "print(f\"图片已保存到 {save_path}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cdca719b-16f7-4557-a5a2-3294f8377138",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.20"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
